# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import json
import logging
import re
import time
import cfg

from datetime import datetime
from flask import current_app, request, url_for, g
from flask_sqlalchemy import Pagination
from sqlalchemy import distinct, select, func, outerjoin, desc, or_
from sqlalchemy.orm import aliased, Bundle, joinedload
from sqlalchemy_fulltext import FullTextSearch
import sqlalchemy_fulltext.modes as FullTextMode
from werkzeug.routing import RequestRedirect

try:
  import urllib2
except ImportError:
  import urllib.request as urllib2

from app.exceptions import InvalidIdentifierException
from data.database import DEFAULT_DATABASE
from data.models import Nvd, RepositoryFilesSchema, Vulnerability, \
  RepositoryFiles, RepositoryFileComments, VulnerabilityGitCommits, User, \
  RepositoryFileMarkers, Cpe
from lib.vcs_management import getVcsHandler, VULN_ID_PLACEHOLDER, HASH_PLACEHOLDER, PATH_PLACEHOLDER

db = DEFAULT_DATABASE


def getVulnerability(filter_by):
  """
  Looks up a Vulnerability entry using a filter dictionary.

  :param filter_by: A dict with either a 'cve_id' or a 'commit_hash' key. If
    both keys are present 'cve_id' takes precedence.
  :return: The result of the matching Vulnerability lookup or None if the
    filter is not a dict or contains neither key (an error is logged in both
    cases).
  """
  if not isinstance(filter_by, dict):
    current_app.logger.error('Received invalid filter.')
    return None

  if 'cve_id' in filter_by:
    return Vulnerability.get_by_cve_id(filter_by['cve_id'])
  if 'commit_hash' in filter_by:
    return Vulnerability.get_by_commit_hash(filter_by['commit_hash'])

  current_app.logger.error('Invalid filter option received.')
  return None


def _measure_exution_time(label):
  """
  Creates a decorator which prints the wall-clock run time of a function.

  :param label: Text used to tag the printed timing line.
  :return: A decorator. The decorated function takes the same arguments and
    returns the same value as the original one, keeps its metadata
    (__name__, __doc__, ...) and prints '[<label>] <seconds>s elapsed' after
    every call that returns normally.
  """
  # Imported locally so that this helper does not rely on a file level import.
  import functools

  def decorator(func):

    # Without functools.wraps every decorated function would be reported as
    # 'wrapper' and lose its docstring.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
      start = time.time()
      res = func(*args, **kwargs)
      end = time.time()

      print('[{}] {}s elapsed'.format(label, end - start))
      return res

    return wrapper

  return decorator


class VulnerabilityView:
  """
  Merges a Vulnerability entry and its Nvd entry into one set of attributes.

  Both constructor arguments may be None. Data from the Vulnerability entry
  is applied first. The Nvd entry then overrides `id`, `cve_id`, `cwe_id` and
  `date_created`, extends the patch/link lists and fills in `comment`,
  `products` and `score` if they are still unset.
  """

  def __init__(self, vulnerability, nvd):
    """
    :param vulnerability: A Vulnerability entry or None.
    :param nvd: A Nvd entry or None.
    """
    self.vulnerability = vulnerability
    self.nvd = nvd

    # The defaults are assigned unconditionally so that a view without any
    # backing entry still exposes every attribute instead of raising an
    # AttributeError on access.
    self.id = None
    self.cve_id = None
    self.date_created = None
    self.vcdb_exists = False
    self.comment = None
    self.master_commit = None
    self.relevant_files = []
    self.products = None
    self.score = None
    self.cwe_id = None
    self.cwe_name = None
    self.known_patches = []
    self.link_references = []
    # TODO: Implement oss flag.
    self.is_oss = False
    self.annotated = False

    self.parent_commit = None
    self.master_commit_stats = None
    self.master_commit_message = None
    self.master_commit_files = None
    self.master_commit_date = None

    if self.vulnerability:
      self._populate_from_vulnerability()

    # Note: if the vulnerability is linked to a CVE nvd must always be set!
    if self.nvd:
      self._populate_from_nvd()

  def _populate_from_vulnerability(self):
    """Copies the data of the Vulnerability entry and its master commit."""
    self.id = self.vulnerability.id
    self.date_created = self.vulnerability.date_created.strftime('%Y-%m-%d')
    self.vcdb_exists = True
    self.master_commit = self.vulnerability.master_commit

    for repo_file in self.master_commit.repository_files:
      self.relevant_files.append('./' + repo_file.file_path)

    if self.master_commit.tree_cache:
      self._populate_from_tree_cache(json.loads(self.master_commit.tree_cache))

    for commit in self.vulnerability.commits:
      if commit is self.master_commit:
        continue
      self.known_patches.append(commit.commit_link)

    # A plain truthiness check also covers a comment which is None.
    if self.vulnerability.comment:
      self.comment = self.vulnerability.comment

    if self.vulnerability.master_commit.num_comments > 0:
      self.annotated = True

  def _populate_from_tree_cache(self, tree_cache):
    """Copies the commit details stored in the decoded tree cache."""
    if 'commit' not in tree_cache:
      return

    commit_data = tree_cache['commit']
    self.parent_commit = commit_data['parent_hash']
    commit_date = datetime.fromtimestamp(commit_data['date'])
    self.master_commit_date = commit_date.strftime('%Y-%m-%d')
    self.master_commit_message = commit_data['message']
    self.master_commit_stats = commit_data['stats']
    self.master_commit_files = commit_data['files']

    # List every patched file which is not listed yet together with a short
    # summary of the change, e.g. './a.c (modified, +2, -3)'.
    for patched_file in self.master_commit_files:
      relevant_file_path = './' + patched_file['path']
      if relevant_file_path in self.relevant_files:
        continue
      patch_stats = [patched_file['status']]
      if patched_file['additions'] > 0:
        patch_stats.append('+' + str(patched_file['additions']))
      if patched_file['deletions'] > 0:
        patch_stats.append('-' + str(patched_file['deletions']))
      self.relevant_files.append(
          relevant_file_path + ' (' + ', '.join(patch_stats) + ')')

  def _populate_from_nvd(self):
    """Applies the data of the Nvd entry on top of the current values."""
    # Always set the following attributes.
    self.id = self.nvd.cve_id
    self.cve_id = self.nvd.cve_id
    self.cwe_id = self.nvd.cwe_id
    cwe_entry = self.nvd.cwe
    if cwe_entry:
      # The cwe relation may be a (non-empty) list or a single entry.
      if isinstance(cwe_entry, list):
        cwe_entry = cwe_entry[0]
      self.cwe_name = cwe_entry.cwe_name

    self.date_created = self.nvd.published_date.strftime('%Y-%m-%d')
    if self.nvd.has_patch():
      remaining_patches = set(self.nvd.get_patches()) - set(self.known_patches)
      # TODO: Refactor the redundant master commit removal below.
      if self.master_commit:
        remaining_patches.discard(self.master_commit.commit_link)
      self.known_patches.extend(remaining_patches)

    remaining_links = set(self.nvd.get_links()) - set(self.known_patches)
    if self.master_commit:
      remaining_links.discard(self.master_commit.commit_link)
    self.link_references.extend(remaining_links)

    # Only set if not already set (avoid accidental overwrites).
    if self.comment is None:
      self.comment = self.nvd.summary
    if self.products is None:
      self.products = ', '.join(self.nvd.get_products())
    if self.score is None:
      self.score = self.nvd.score


class VulnViewPaginationObjectWrapper(Pagination):
  """
  Wraps a Flask SQLAlchemy Pagination object and replaces its Vuln/Nvd items
  with VulnerabilityView instances.
  """

  def __init__(self, paginationObject):
    """
    :param paginationObject: A Flask SQLalchemy Pagination object.
    """
    wrapped_class = paginationObject.__class__
    # Become an ad-hoc subclass of both this wrapper and the wrapped class.
    self.__class__ = type(wrapped_class.__name__,
                          (self.__class__, wrapped_class), {})
    # The attribute dictionary is shared with the wrapped object, it is not
    # copied.
    self.__dict__ = paginationObject.__dict__

    self._wrap_items()

  def _wrap_items(self):
    """
    Replaces self.items with a list of VulnerabilityView objects.

    The type of the first item decides how all items are interpreted: as Nvd
    entries, as Vulnerability entries or as (Vulnerability, Nvd) pairs. An
    empty item list is left untouched.
    :return:
    """
    if len(self.items) == 0:
      return

    head = self.items[0]
    if isinstance(head, Nvd):
      wrapped = [VulnerabilityView(None, entry) for entry in self.items]
    elif isinstance(head, Vulnerability):
      wrapped = [VulnerabilityView(entry, None) for entry in self.items]
    else:
      wrapped = [
          VulnerabilityView(vulnerability, nvd)
          for vulnerability, nvd in self.items
      ]

    self.items = wrapped

  def prev(self, error_out=False):
    """Overridden as a no-op which always returns None."""

  def next(self, error_out=False):
    """Overridden as a no-op which always returns None."""


class VulncodeDB:
  #
  # def getEntriesWithPatch(self):
  #   cve_entries = select([Nvd, Vulnerability]).select_from(
  #       outerjoin(Nvd, Vulnerability)).apply_labels() #.order_by(desc(Vulnerability.date_created))
  #
  #
  #   # TODO: find performant way to simulate "union" or to deduplicate the
  #   # entries from both data sets... This still takes ~800ms :/...
  #   query_union = db.session.query(Nvd, Vulnerability).select_entity_from(
  #       vcdb_entries.union_all(cve_entries))
  #   #
  #   #cve_entries = db.session.query(Nvd, Vulnerability).outerjoin(Vulnerability)
  #   #query_union = vcdb_entries.union(cve_entries)
  #
  #   #entry = query_union[0]
  #
  #   #entry = entries[0]
  #
  #   #entries = Nvd.get_all_by_link_regex(cfg.PATCH_REGEX)
  #   return query_union

  #@_measure_exution_time('Vulncode-DB Constructor')
  def __init__(self):
    """
    Builds the paginated entry listings for the current request.

    Reads the `keyword`, `vcdb_p` and `nvd_p` request arguments. The results
    are exposed as `self.vcdb_pagination` (Vulnerability entries together
    with their Nvd entry) and `self.nvd_pagination` (Nvd entries without a
    Vulnerability entry), both wrapped in a VulnViewPaginationObjectWrapper.
    """
    # TODO: Look into enabling this once public contributions are enabled.
    #self.top_contributors = []
    #self.fetch_top_contributors()

    self.vcdb_entries = db.session.query(Vulnerability, Nvd).select_from(
        outerjoin(Vulnerability, Nvd)).options(
            joinedload(Nvd.cpes).load_only(Cpe.product)).order_by(
                desc(Vulnerability.date_created))

    # `.is_(None)` is the explicit spelling of the IS NULL check.
    self.nvd_entries = db.session.query(Nvd).outerjoin(Vulnerability).options(
        joinedload(Nvd.cpes).load_only(
            Cpe.product)).filter(Vulnerability.cve_id.is_(None)).order_by(
                desc(Nvd.created_at))

    self.keyword = request.args.get('keyword', None, type=str)

    apply_filter = None
    if self.keyword:
      # TODO: Make the filtering work with fulltext search as well.
      if VulnerabilityDetails.is_cve_id(self.keyword):
        apply_filter = or_(False, Nvd.cve_id == self.keyword)
      elif VulnerabilityDetails.is_vcdb_id(self.keyword):
        apply_filter = or_(False, Vulnerability.id == self.keyword)
      else:
        # Raw string: '\W' is not a valid escape sequence in a normal string
        # literal.
        escaped_keyword = re.sub(r'\W+', ' ', self.keyword)
        # Attention: We can't use FullText search here because of some buggy
        # MySQL 5.7 behavior (using FullText on join results seems to do bad
        # things). We might need to apply the filter before joining below.
        # apply_filter = or_(
        #     FullTextSearch(escaped_keyword, Nvd, FullTextMode.BOOLEAN),
        #     FullTextSearch(escaped_keyword, Vulnerability, FullTextMode.BOOLEAN))
        # NOTE: '%' is removed by the substitution above but '_' is a word
        # character and therefore still acts as a LIKE wildcard.
        like_pattern = '%' + escaped_keyword + '%'
        apply_filter = or_(
            Nvd.summary.like(like_pattern),
            Vulnerability.comment.like(like_pattern))

      # TODO: add product search support.
      #apply_filter = or_(apply_filter, Cpe.product == keyword)

    if apply_filter is not None:
      self.vcdb_entries = self.vcdb_entries.filter(apply_filter)
      self.nvd_entries = self.nvd_entries.filter(apply_filter)

    per_page = 7
    vcdb_page = request.args.get('vcdb_p', 1, type=int)
    self.vcdb_pagination = self.vcdb_entries.paginate(
        vcdb_page, per_page=per_page)
    self.vcdb_pagination = VulnViewPaginationObjectWrapper(self.vcdb_pagination)

    nvd_page = request.args.get('nvd_p', 1, type=int)
    self.nvd_pagination = self.nvd_entries.paginate(nvd_page, per_page=per_page)
    self.nvd_pagination = VulnViewPaginationObjectWrapper(self.nvd_pagination)

  # @_measure_exution_time('TOP CONTRIB')
  def fetch_top_contributors(self):
    """
    Loads the ten users with the most active annotations.

    Stores a list of (User, num_comments, num_markers) rows in
    `self.top_contributors`, ordered by the combined count (descending). Users
    without any active comment or marker are not included.
    """
    # TODO: count number of contributions to vulnerabilities instead of single
    # annotations.
    comment_counts = self.get_annotation_query(RepositoryFileComments)
    marker_counts = self.get_annotation_query(RepositoryFileMarkers)

    # Both subqueries are outer joined, so a user who has only comments (or
    # only markers) gets NULL for the other count. NULL + n is NULL in SQL
    # which would drop such a user in the filter below. Hence the counts are
    # coalesced before they are added up.
    num_comments = func.coalesce(comment_counts.c.count, 0)
    num_markers = func.coalesce(marker_counts.c.count, 0)
    num_both = num_comments + num_markers

    self.top_contributors = (
        db.session.query(
            User,
            num_comments.label('num_comments'),
            num_markers.label('num_markers'))
        .outerjoin(comment_counts, comment_counts.c.creator_id == User.id)
        .outerjoin(marker_counts, marker_counts.c.creator_id == User.id)
        .filter(num_both > 0)
        .order_by(num_both.desc())
        .limit(10)
        .all())

  @staticmethod
  def get_annotation_query(model):
    """
    Builds a subquery which counts the active rows of a model per creator.

    :param model: A model class providing a `creator_id` and an `active`
      column.
    :return: A subquery with the columns `creator_id` and `count`.
    """
    creator = model.creator_id
    counts = db.session.query(
        creator.label('creator_id'),
        func.count(1).label('count'))
    active_counts = counts.filter_by(active=True)
    return active_counts.group_by(creator).subquery()

  # @_measure_exution_time('NEW 1')
  # def _new_method(self):
  #   files_with_comments = select(
  #       [distinct(RepositoryFileComments.repository_file_id)])
  #   commits_with_comments = db.session.query(
  #       distinct(RepositoryFiles.commit_id)).filter(
  #           RepositoryFiles.id.in_(files_with_comments))
  #   base_query = Vulnerability.query.join(Vulnerability.commits)
  #   self.annotated_entries = base_query.filter(
  #       VulnerabilityGitCommits.id.in_(commits_with_comments)).all()
  #   #self.empty_entries = base_query.filter(
  #   #~VulnerabilityGitCommits.id.in_(commits_with_comments)).all()

  # @_measure_exution_time('NEW 2')
  # def _new_method2(self):
  #   base_query = Vulnerability.query.join(
  #       Vulnerability.commits, VulnerabilityGitCommits.repository_files,
  #       RepositoryFiles.comments)
  #   # print base_query.having(func.count(RepositoryFileComments.id) > 0)
  #   self.annotated_entries = base_query.having(
  #       func.count(RepositoryFileComments.id) > 0).all()
  #   self.empty_entries = base_query.having(
  #       func.count(RepositoryFileComments.id) == 0).all()

  # @_measure_exution_time('OLD')
  # def _old_method(self):
  #   self.annotated_entries = []
  #   self.empty_entries = []
  #   all_entries = Vulnerability.query.all()
  #
  #   for entry in all_entries:
  #     master_commit = entry.commits[0]
  #     if master_commit.num_comments > 0:
  #       self.annotated_entries.append(entry)
  #     else:
  #       self.empty_entries.append(entry)


class VulnerabilityDetails:

  def __init__(self, vuln_id=None):
    """
    Creates an empty details object and populates it from the request.

    :param vuln_id: Optional identifier which is used as the suggested id.
    """
    self.suggested_id = vuln_id

    # Identifiers.
    self.id = self.vcdb_id = self.cve_id = None
    # Commit and repository information.
    self.commit_link = self.commit_hash = None
    self.repo_url = self.repo_name = None
    # URLs provided by the VCS handler.
    self.tree_url = self.file_url = None
    self.file_provider_url = self.file_ref_provider_url = None
    # Loaded database entries and the view built from them.
    self._vulnerability = self._nvd_data = None
    self.vulnerability_view = None

    self.populate_from_request()
    logging.debug('Loaded vulnerability details %r', self)

  def __repr__(self):
    """
    Renders the class name followed by `name=value` pairs of all members
    which are neither functions nor methods and whose name does not start
    with an underscore.
    """

    def is_data(member):
      return not (inspect.isfunction(member) or inspect.ismethod(member))

    rendered = []
    for name, value in inspect.getmembers(self, is_data):
      if name.startswith('_'):
        continue
      rendered.append('{}={!r}'.format(name, value))
    return '{}({})'.format(type(self).__name__, ', '.join(rendered))

  def update_details(self):
    """
    Adds the loaded vulnerability to the database session and commits it.
    """
    session = db.session
    session.add(self._vulnerability)
    session.commit()

  def get_nvd_entry(self):
    """
    :return: The Nvd entry stored on this object (None if none was loaded).
    """
    nvd_entry = self._nvd_data
    return nvd_entry

  def populate_from_request(self):
    """
    Resets the identifiers from the current request and loads the data.

    An unset suggested id falls back to the `id` query argument and then to
    the `id` form field. A recognized suggested id overrides the matching
    identifier taken from the request. Commit links and `repo_url||hash`
    pairs are only accepted as suggested id for POST requests.
    """
    suggested = self.suggested_id
    if not suggested:
      suggested = (request.args.get('id', None, type=str) or
                   request.form.get('id', None, type=str))
      self.suggested_id = suggested

    self.id = None
    query_args = request.args
    self.vcdb_id = query_args.get('vcdb_id', None, type=str)
    self.cve_id = query_args.get('cve_id', None, type=str)
    form = request.form
    self.commit_link = form.get('commit_link', None, type=str)
    self.commit_hash = form.get('commit_hash', None, type=str)
    self.repo_url = form.get('repo_url', None, type=str)
    self.repo_name = None
    self._vulnerability = None
    self._nvd_data = None
    self.file_provider_url = None
    self.file_ref_provider_url = None
    self.vulnerability_view = None

    # The suggested id will overwrite other identifiers accordingly.
    if suggested:
      if self.is_cve_id(suggested):
        self.cve_id = suggested
        logging.debug('Suggested id %r recognized as CVE', suggested)
      elif self.is_vcdb_id(suggested):
        self.vcdb_id = suggested
        logging.debug('Suggested id %r recognized as VCDB_ID', suggested)

      # Handle more complex IDs only via POST.
      if request.method == 'POST':
        if self.is_commit_link(suggested):
          self.commit_link = suggested
          logging.debug('Suggested id %r recognized as commit link', suggested)
        elif self.is_repo_data(suggested):
          parts = suggested.split('||')
          self.repo_url, self.commit_hash = parts[0], parts[1]
          logging.debug('Suggested id %r recognized as raw repo link',
                        suggested)

    self._fetch_data()

  def populate_from_model(self, model):
    self._vulnerability = model
    self.cve_id = model.cve_id
    self.vcdb_id = model.id
    self.commit_hash = model.master_commit.commit_hash
    self.commit_link = model.master_commit.commit_link
    self.repo_url = model.master_commit.repo_url
    self.repo_name = model.master_commit.repo_name

  def validate(self):
    """
    Makes sure an id could be determined and the canonical URL is used.

    :raises InvalidIdentifierException: If neither a CVE id nor a VCDB id is
      available.
    :raises RequestRedirect: For GET requests whose suggested id is unset or
      differs from the determined id.
    """
    self._set_id()
    if not self.id:
      raise InvalidIdentifierException(
          'Please provide a valid CVE ID or Git commit link.')

    # Always redirect to the most simple URL.
    if request.method != 'GET':
      return
    if self.suggested_id and self.suggested_id == self.id:
      return
    raise RequestRedirect('/' + str(self.id))

  def _set_id(self):
    """
    Sets self.id to the first identifier which is not None, checking the CVE
    id before the VCDB id. self.id becomes None if both are None.
    """
    candidates = (self.cve_id, self.vcdb_id)
    self.id = next(
        (candidate for candidate in candidates if candidate is not None),
        None)

  def getSettings(self):
    """
    Collects the settings of this vulnerability in a dictionary.

    :return: A dict with the commit/repository information, the URLs used to
      fetch tree, file and annotation data, the placeholder constants and, if
      the entry is annotated and has a master commit, the serialized
      repository files as `custom_data`.
    """
    view = self.vulnerability_view
    # Must be a plain None when there is no view: a trailing comma
    # (`None,`) would turn the value into the tuple (None,).
    parent_hash = view.parent_commit if view else None

    file_provider_url = self.file_provider_url
    if file_provider_url:
      file_provider_url = file_provider_url.replace(VULN_ID_PLACEHOLDER,
                                                    self.id)
    file_ref_provider_url = self.file_ref_provider_url
    if file_ref_provider_url:
      file_ref_provider_url = file_ref_provider_url.replace(
          VULN_ID_PLACEHOLDER, self.id)

    data = {
        'commit_link': self.commit_link,
        'commit_hash': self.commit_hash,
        'repo_url': self.repo_url,
        'repo_name': self.repo_name,
        'tree_url': url_for('vuln.vuln_file_tree', vuln_id=self.id),
        'annotation_data_url': url_for('vuln.annotation_data', vuln_id=self.id),
        'file_provider_url': file_provider_url,
        'file_ref_provider_url': file_ref_provider_url,
        'file_url': self.file_url,
        'id': self.id,
        'parent_hash': parent_hash,
        'HASH_PLACEHOLDER': HASH_PLACEHOLDER,
        'PATH_PLACEHOLDER': PATH_PLACEHOLDER,
    }

    # The view may be unset, so it has to be checked before it is used.
    if view and view.annotated:
      master_commit = self.getMasterCommit()
      if master_commit:
        files_schema = RepositoryFilesSchema(many=True)
        # TODO: Consider refactoring this section. We currently also fetch
        #  custom data from the backend.
        # Hack to quickly retrieve the full data.
        data['custom_data'] = json.loads(
            files_schema.jsonify(master_commit.repository_files).data)

    return data

  def hasCustomData(self):
    master_commit = self.getMasterCommit()
    if not master_commit or not master_commit.repository_files:
      return False
    files_schema = RepositoryFilesSchema(many=True)
    custom_data = files_schema.dump(master_commit.repository_files).data
    return len(custom_data) > 0

  def getMasterCommit(self):
    if not self._vulnerability:
      return None
    return self._vulnerability.master_commit

  def _init_repo_data(self):
    """
    Initializes the VCS specific attributes from a matching VCS handler.

    A commit link containing 'github.com' is preferred as resource URL,
    otherwise the repository URL is used before the commit link.

    :return: False if no resource URL is available, True otherwise.
    :raises InvalidIdentifierException: If no VCS handler is found for the
      resource URL or no commit hash could be determined.
    """
    link = self.commit_link
    if link and 'github.com' in link:
      resource_url = link
    else:
      resource_url = self.repo_url or link

    logging.info('Searching VCS handler for %s', resource_url)
    if not resource_url:
      return False

    handler = getVcsHandler(current_app, resource_url)
    if not handler:
      raise InvalidIdentifierException('Please provide a valid resource link.')

    self.repo_name = handler.repo_name
    self.file_provider_url = handler.getFileProviderUrl()
    self.file_ref_provider_url = handler.getRefFileProviderUrl()
    self.file_url = handler.getFileUrl()
    self.tree_url = handler.getTreeUrl()
    # An already known commit hash wins over the one of the handler.
    self.commit_hash = self.commit_hash or handler.commit_hash
    if not self.commit_hash:
      raise InvalidIdentifierException(
          'Couldn\'t extract commit hash from given resource URL.')
    return True

  @staticmethod
  def is_cve_id(id):
    # strict check from https://cve.mitre.org/cve/identifiers/tech-guidance.html
    return re.match(r'^CVE-\d{4}-(0\d{3}|[1-9]\d{3,})$', id,
                    re.IGNORECASE) is not None

  @staticmethod
  def is_vcdb_id(id):
    return id.isdigit()

  @staticmethod
  def is_commit_link(id):
    return re.match(r'.*github.com/', id) is not None

  @staticmethod
  def is_repo_data(id):
    return '||' in id

  def _fetch_by_id(self):
    """
    Loads the vulnerability by its VCDB id or, if that is unset, by its CVE
    id. Does nothing if neither id is set.

    :raises InvalidIdentifierException: If the CVE id is used and malformed.
    """
    if self.vcdb_id:
      self._vulnerability = Vulnerability.query.get(self.vcdb_id)
      return

    if not self.cve_id:
      return
    if not self.is_cve_id(self.cve_id):
      raise InvalidIdentifierException('Please provide a valid CVE ID.')
    self._vulnerability = Vulnerability.get_by_cve_id(self.cve_id)

  def _fetch_by_commit_hash(self):
    """
    Loads the vulnerability by its commit hash unless a vulnerability is
    already loaded or no commit hash is known.
    """
    if not self._vulnerability and self.commit_hash:
      self._vulnerability = Vulnerability.get_by_commit_hash(self.commit_hash)

  def fetch_tree_cache(self, skip_errors=True, max_timeout=4):
    """
    (Pre)fetches / updates the file tree cache if required and possible.
    """
    master_commit = self.getMasterCommit()
    if master_commit and not master_commit.tree_cache:
      # Fetch the required data from our VCS proxy.
      try:
        # Fetch the required data from our VCS proxy.
        proxy_target = cfg.GCE_VCS_PROXY_URL + url_for(
            'vcs_proxy.main_api',
            commit_link=self.commit_link,
            commit_hash=self.commit_hash,
            repo_url=self.repo_url)[1:]
        result = urllib2.urlopen(proxy_target, timeout=max_timeout)
        master_commit.tree_cache = result.read()
        self.update_details()
      except Exception as e:
        if not skip_errors:
          raise e

  def _fetch_data(self):
    """
    Loads the Vulnerability and Nvd entries which match the identifiers and
    builds the vulnerability view if at least one entry was found.
    """
    self._fetch_by_id()
    master_commit = self.getMasterCommit()
    if master_commit:
      self.commit_link = master_commit.commit_link
      self.commit_hash = master_commit.commit_hash
      self.repo_url = master_commit.repo_url

    # Initialize VCS specific data.
    self._init_repo_data()
    self._fetch_by_commit_hash()

    # Fetch IDs from vulnerability entry if it exists.
    vulnerability = self._vulnerability
    if vulnerability:
      self.vcdb_id = str(vulnerability.id)
      if vulnerability.cve_id:
        self.cve_id = vulnerability.cve_id

    # Fetch corresponding NVD data if possible.
    if self.cve_id:
      self._nvd_data = Nvd.get_by_cve_id(self.cve_id)
    elif self.commit_hash:
      self._nvd_data = Nvd.get_by_commit_hash(self.commit_hash)

    # Make sure to always use the properly formatted CVE-ID if available.
    if self._nvd_data:
      self.cve_id = self._nvd_data.cve_id

    if not (self._vulnerability or self._nvd_data):
      return
    self.fetch_tree_cache()
    self.vulnerability_view = VulnerabilityView(self._vulnerability,
                                                self._nvd_data)

  def get_or_create_vulnerability(self):
    """
    Returns the loaded vulnerability or builds a new Vulnerability object.

    A new object uses the CVE id of the loaded Nvd entry (if any), a single
    commit built from the commit/repository attributes, an empty comment and
    the current user (g.user) as creator. This method does not add the new
    object to the database session itself.
    """
    existing = self._vulnerability
    if existing:
      return existing

    nvd_entry = self._nvd_data
    cve_id = nvd_entry.cve_id if nvd_entry is not None else None
    master_commit = VulnerabilityGitCommits(
        commit_link=self.commit_link,
        repo_name=self.repo_name,
        repo_url=self.repo_url,
        commit_hash=self.commit_hash)
    return Vulnerability(
        cve_id=cve_id,
        commits=[master_commit],
        comment='',
        creator=g.user,
    )
